Skip to content

PR 4 of #508 — InferenceLayer ABC + SingleInstanceLayer (#512)#534

Merged
gitttt-1234 merged 2 commits into
divya/inf-refactor-03-torch-backendfrom
divya/inf-refactor-04-inference-layer-single
May 28, 2026
Merged

PR 4 of #508 — InferenceLayer ABC + SingleInstanceLayer (#512)#534
gitttt-1234 merged 2 commits into
divya/inf-refactor-03-torch-backendfrom
divya/inf-refactor-04-inference-layer-single

Conversation

@gitttt-1234

Copy link
Copy Markdown
Collaborator

Summary

Closes #512. Lock down the InferenceLayer abstraction with the proof-of-pattern single-instance subclass — and prove parity end-to-end against the PR 0 golden. Every subsequent layer (PR 6) and the Predictor orchestrator (PR 8) follows this template.

What lands

  • sleap_nn/inference/layers/configs.pyattrs.frozen PreprocessConfig + PostprocessConfig
  • sleap_nn/inference/layers/base.pyInferenceLayer ABC: abstract preprocess / postprocess, concrete predict / __call__ / warmup. Includes _to_4d_float_tensor helper that accepts (H,W), (H,W,C), (C,H,W), (B,H,W,C), (B,C,H,W) numpy or torch — gives every subclass a uniform input contract for the new direct-numpy API
  • sleap_nn/inference/layers/single_instance.pySingleInstanceLayer. Decodes confmaps via ops.peaks.find_global_peaks, applies the full coord ladder via ops.coord, returns Outputs

Headline result

End-to-end parity proven by test_single_instance_layer_parity_vs_pr0_golden:

new_outputs = SingleInstanceLayer(...).predict(image_4d)
np.testing.assert_allclose(
    new_outputs.pred_keypoints.squeeze(1).numpy(),
    golden[\"pred_instance_peaks\"],
    atol=1e-5, rtol=1e-5,
)  # PASSES

The new layer stack — np.ndarray in, Outputs dataclass out — matches the captured-from-old-code golden within 1e-5 atol/rtol on the same fixed input. Linchpin of the refactor: as long as this passes, every subsequent PR (5–14) is verifiably parity-preserving on the single-instance path.

Test plan

tests/inference/layers/ — 22 cases:

  • InferenceLayer ABC enforcement (cannot instantiate directly; rejects non-ModelBackend)
  • _to_4d_float_tensor: 5 input shapes parametrized + numpy/torch + rejection of unsupported types/ranks
  • Parity vs PR 0 golden within 1e-5
  • Direct numpy + torch APIs
  • 2D grayscale + 4D batched inputs
  • return_confmaps off-by-default + opt-in populates Outputs.pred_confmaps
  • Synthetic-input coord ladder verification

SingleInstanceLayer.from_checkpoint(...) is deferred to PR 8 (#516) where checkpoint-load logic consolidates. For now the constructor takes an already-built ModelBackend; tests build via the existing Predictor.from_model_paths and pull modules out of inference_model.

Full inference suite: 240 passed, 8 skipped, 2 xfailed.
Parity regen: 30/30 specs match.

🤖 Generated with Claude Code

@codecov

codecov Bot commented May 1, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 75.07297% with 427 lines in your changes missing coverage. Please review.
✅ Project coverage is 61.66%. Comparing base (00f2d64) to head (33b0c52).

Files with missing lines Patch % Lines
sleap_nn/inference/layers/exported.py 0.00% 65 Missing ⚠️
...p_nn/inference/layers/backends/tensorrt_backend.py 25.88% 63 Missing ⚠️
sleap_nn/inference/factory.py 59.15% 58 Missing ⚠️
sleap_nn/inference/layers/centroid.py 62.38% 41 Missing ⚠️
sleap_nn/inference/layers/backends/onnx_backend.py 37.50% 40 Missing ⚠️
sleap_nn/inference/predictor.py 73.77% 32 Missing ⚠️
sleap_nn/inference/layers/topdown_multiclass.py 47.27% 29 Missing ⚠️
sleap_nn/cli.py 85.71% 23 Missing ⚠️
sleap_nn/inference/streaming.py 88.57% 12 Missing ⚠️
sleap_nn/inference/filters.py 92.30% 11 Missing ⚠️
... and 8 more
Additional details and impacted files
@@                           Coverage Diff                           @@
##           divya/inf-refactor-03-torch-backend     #534      +/-   ##
=======================================================================
- Coverage                                63.65%   61.66%   -1.99%     
=======================================================================
  Files                                      109      129      +20     
  Lines                                    17918    19607    +1689     
=======================================================================
+ Hits                                     11406    12091     +685     
- Misses                                    6512     7516    +1004     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@gitttt-1234 gitttt-1234 force-pushed the divya/inf-refactor-03-torch-backend branch from 1427454 to 90758f0 Compare May 1, 2026 16:42
@gitttt-1234 gitttt-1234 force-pushed the divya/inf-refactor-04-inference-layer-single branch from aae919a to 5ba3150 Compare May 1, 2026 16:42
…#508)

Lock down the InferenceLayer abstraction with the proof-of-pattern
single-instance subclass — and prove parity end-to-end against the PR 0
golden. Every subsequent layer (PR 6) and the Predictor orchestrator
(PR 8) follows this template.

Layout:

  sleap_nn/inference/layers/
    configs.py            attrs.frozen PreprocessConfig + PostprocessConfig
    base.py               InferenceLayer ABC: preprocess / postprocess /
                          predict / __call__ / warmup. Includes the
                          _to_4d_float_tensor helper that accepts (H,W),
                          (H,W,C), (C,H,W), (B,H,W,C), (B,C,H,W) numpy or
                          torch — gives every subclass a uniform input
                          contract for the new direct-numpy API.
    single_instance.py    SingleInstanceLayer (concrete). Decodes confmaps
                          via ops.peaks.find_global_peaks, applies the
                          full coord ladder via ops.coord, returns Outputs.

End-to-end parity proven by ``test_single_instance_layer_parity_vs_pr0_golden``:
the new layer's pred_keypoints / pred_peak_values match the PR 0 golden
within 1e-5 atol/rtol on the captured fixed input. This test is the
linchpin of the entire refactor — every PR 5–14 will keep gating on it.

Tests (`tests/inference/layers/`, 22 cases):

* InferenceLayer ABC enforcement: cannot instantiate directly; rejects
  non-ModelBackend; predict / __call__ agreement.
* _to_4d_float_tensor: 5 input shapes parametrized + numpy/torch +
  rejection of unsupported types and ranks.
* SingleInstanceLayer:
  - parity vs PR 0 golden (within 1e-5)
  - direct numpy API
  - direct torch API
  - 2D grayscale + 4D batched inputs
  - return_confmaps off-by-default + opt-in populates Outputs.pred_confmaps
  - synthetic-input coord ladder verification (output_stride scaling)

Full inference suite: 240 passed, 8 skipped, 2 xfailed.
Parity regen (RUN_GOLDEN_REGEN_CHECK=1): 30/30 specs match byte-for-byte.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@gitttt-1234 gitttt-1234 force-pushed the divya/inf-refactor-03-torch-backend branch from 90758f0 to 00f2d64 Compare May 1, 2026 19:52
@gitttt-1234 gitttt-1234 force-pushed the divya/inf-refactor-04-inference-layer-single branch from 5ba3150 to ab68fca Compare May 1, 2026 19:52
…513) (#535)

## Summary

Closes #513. Three ops in `sleap_nn/inference/ops` get rewritten to
remove constructs the legacy TorchScript ONNX exporter rejected, while
preserving prior behavior bit-exactly (or within 1 ULP).

## Op rewrites

- **`morphological_dilation`** — replace `Tensor.unfold` with an
explicit 8-shift max stack. Drops the `xfail(strict=True)` marker on its
ONNX-export smoke test from PR 1; that test now passes.
`find_local_peaks_rough` gains exportability as a side effect (was
previously documented as a \"known gap\" expecting a raise; now it
succeeds).

- **`find_global_peaks_rough`** — replace boolean-mask `index_put_`
in-place assignment with `torch.where`. Also splits a single
`squeeze(dim=(2, 3))` into two single-dim `squeeze` calls (the legacy
exporter does not lower the multi-dim form). PR 1's `xfail(strict=True)`
marker dropped; smoke test now passes.

- **`crop_bboxes`** — replace per-peak `unfold` +
advanced-indexing-on-unfolded-view with direct advanced indexing on a
zero-padded image. Bbox top-lefts are floored before extraction
(matching the prior `.to(torch.long)` truncation), so integer-aligned
bboxes produce bit-exact crops; sub-pixel bboxes (centroid-driven
top-down stage 2) reproduce the old \"snap to integer pixel\" behavior
exactly. Avoids the bilinear-interp drift that `F.grid_sample` would
introduce while still removing `unfold`.

## Tests updated

- Drop the two PR 1 `xfail(strict=True)` markers (auto-flipped to
passing)
- Invert `find_local_peaks_rough_known_export_gap` — the function now
exports for fixed-shape examples; test asserts that and verifies output
parity. Variable-peak-count output remains a runtime-shape constraint
that PR 7 (#515) addresses via `find_top_k_peaks`
- Bump `test_golden_is_reproducible` float tolerance from strict zero to
`atol=1e-5`, `rtol=1e-6` to absorb 1-ULP drift from the `torch.where`
rewrite. Two orders of magnitude tighter than the design-doc budget
(1e-4 / 1e-5); integer fields still compared exactly

## CUDA test suite

Bonus: `tests/inference/test_cuda.py` — 12 module-level-skipif-gated
tests covering pure ops, `Outputs` device transfer,
`TorchBackend(device='cuda')`, `SingleInstanceLayer` cross-device
parity, FP16 drift budget, pin_memory transfer correctness. Skip cleanly
on non-CUDA hosts; on a CUDA box run with:

```bash
pytest tests/inference/test_cuda.py -v
```

## Test plan

- [x] `RUN_GOLDEN_REGEN_CHECK=1` — 30/30 specs reproduce within 1e-5 of
PR 0 goldens (subprocess-isolated capture)
- [x] Full inference suite: 242 passed, 8 skipped (CUDA-only) + 12
skipped (CUDA suite), 0 failed, 0 xfailed
- [x] PR 4's `SingleInstanceLayer` parity test still passes at 1e-5
atol/rtol
- [x] Both ONNX-export `xfail` markers from PR 1 dropped — corresponding
tests now pass cleanly

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@gitttt-1234 gitttt-1234 marked this pull request as ready for review May 28, 2026 18:14
@gitttt-1234 gitttt-1234 merged commit a4120fd into divya/inf-refactor-03-torch-backend May 28, 2026
@gitttt-1234 gitttt-1234 deleted the divya/inf-refactor-04-inference-layer-single branch May 28, 2026 18:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant